import numpy as np
# from bartpy.extensions.baseestimator import ResidualBART
from sklearn.linear_model import LinearRegression
from causally.model.abstract_model import SKAbstractModel
from sklearn.neural_network import MLPRegressor
from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor
from causalml.inference.meta import XGBTRegressor, MLPTRegressor, LRSRegressor, XGBRRegressor


class BART(SKAbstractModel):
    """Treatment-effect model backed by causalml's R-learner.

    Despite the class name, the active model is ``BaseRRegressor`` with a
    scikit-learn ``MLPRegressor`` as its base learner; the residual-BART
    model this class was named after is disabled (see note in ``__init__``).
    ``n_jobs``, ``n_trees`` and ``n_units`` are still read and stored as
    attributes but are not used by the current model.
    """

    def __init__(self, config, dataset):
        """Read settings from ``config``/``dataset`` and build the R-learner.

        Args:
            config: mapping that must provide ``'n_jobs'`` and ``'n_trees'``.
            dataset: object exposing ``get_X_size()``; element ``[0]`` of its
                result is stored as ``n_units``.
        """
        super(BART, self).__init__(config, dataset)
        # Kept for the (disabled) BART backend; unused by the R-learner below.
        self.n_jobs = config['n_jobs']
        self.n_trees = config['n_trees']
        self.n_units = dataset.get_X_size()[0]
        # NOTE: a bartpy ResidualBART(base_estimator=LinearRegression(),
        # n_jobs=self.n_jobs, n_trees=self.n_trees) model used to live here;
        # it is disabled and replaced by the R-learner.
        self.model = BaseRRegressor(learner=MLPRegressor())

    def calculate_loss(self, x, t, y, w):
        """Fit the R-learner on covariates ``x``, treatment ``t``, outcome ``y``.

        Despite its name, no loss is computed and ``None`` is returned.
        ``w`` is accepted for interface compatibility but is ignored.
        """
        self.model.fit(X=x, treatment=t, y=y)

    def predict(self, x, t_0, t_1):
        """Return the estimated effect of treatment group ``1`` for each row of ``x``.

        ``t_0`` and ``t_1`` are accepted for interface compatibility but unused.

        Returns:
            1-D array with one estimated treatment effect per row of ``x``.

        Raises:
            ValueError: if ``1`` was not among the treatment groups seen in fit.
        """
        # BaseRRegressor.predict has no ``return_components`` option (that is
        # a T/S/X-learner feature), so requesting components raised TypeError.
        # It returns the effect matrix directly: one column per treatment
        # group, ordered as in ``self.model.t_groups``.
        te = self.model.predict(x)
        column = list(self.model.t_groups).index(1)
        return te[:, column]
